{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b518b04cbfe0"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "906e07f6e562"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a5620ee4049e"
      },
      "source": [
        "# 自定义 Model.fit 的内容"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0a56ffedf331"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>     <a target=\"_blank\" href=\"https://tensorflow.google.cn/guide/keras/customizing_what_happens_in_fit\"><img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\">在 TensorFlow.org 上查看</a>\n",
        "</td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/zh-cn/guide/keras/customizing_what_happens_in_fit.ipynb\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\">在 Google Colab 中运行</a></td>\n",
        "  <td>     <a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/guide/keras/customizing_what_happens_in_fit.ipynb\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\">在 GitHub 上查看源代码</a>\n",
        "</td>\n",
        "  <td>     <a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/zh-cn/guide/keras/customizing_what_happens_in_fit.ipynb\"><img src=\"https://tensorflow.google.cn/images/download_logo_32px.png\">下载笔记本</a>\n",
        "</td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7ebb4e65ef9b"
      },
      "source": [
        "## 简介\n",
        "\n",
        "您在进行监督学习时可以使用 `fit()`，一切都可以顺利完成。\n",
        "\n",
        "需要从头开始编写自己的训练循环时，您可以使用 `GradientTape` 并控制每个微小的细节。\n",
        "\n",
        "但如果您需要自定义训练算法，又想从 `fit()` 的便捷功能（例如回调、内置分布支持或步骤融合）中受益，那么该怎么做？\n",
        "\n",
        "Keras 的核心原则是**渐进式呈现复杂性**。您应当始终能够以渐进的方式习惯较低级别的工作流。如果高级功能并不完全符合您的用例，那么您就不应深陷其中。您应当能够从容地控制微小的细节，同时保留与之相称的高级便利性。\n",
        "\n",
        "需要自定义 `fit()` 的功能时，您应**重写 `Model` 类的训练步骤函数**。此函数是 `fit()` 会针对每批次数据调用的函数。然后，您将能够像往常一样调用 `fit()`，它将运行您自己的学习算法。\n",
        "\n",
        "请注意，此模式不会妨碍您使用函数式 API 构建模型。无论是构建 `Sequential` 模型、函数式 API 模型还是子类模型，均可采用这种模式。\n",
        "\n",
        "让我们了解一下它的工作方式。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2849e371b9b6"
      },
      "source": [
        "## 设置\n",
        "\n",
        "需要 TensorFlow 2.2 或更高版本。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4dadb6688663"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "from tensorflow import keras"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9022333acaa7"
      },
      "source": [
        "## 第一个简单的示例\n",
        "\n",
        "让我们从一个简单的示例开始：\n",
        "\n",
        "- 创建一个将 `keras.Model` 子类化的新类。\n",
        "- 仅重写 `train_step(self, data)` 方法。\n",
        "- 返回一个将指标名称（包括损失）映射到其当前值的字典。\n",
        "\n",
        "输入参数 `data` 是传递以拟合训练数据的数据：\n",
        "\n",
        "- 如果通过调用 `fit(x, y, ...)` 传递 Numpy 数组，则 `data` 将为元祖 `(x, y)`。\n",
        "- 如果通过调用 `fit(dataset, ...)` 传递 `tf.data.Dataset`，则 `data` 将为每批次 `dataset` 产生的数据。\n",
        "\n",
        "我们在 `train_step` 方法的主体中实现了定期的训练更新，类似于您已经熟悉的内容。重要的是，**我们通过 `self.compiled_loss` 计算损失**，它会封装传递给 `compile()` 的损失函数。\n",
        "\n",
        "同样，我们调用 `self.compiled_metrics.update_state(y, y_pred)` 来更新在 `compile()` 中传递的指标的状态，并在最后从 `self.metrics` 中查询结果以检索其当前值。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "060c8bf4150d"
      },
      "outputs": [],
      "source": [
        "class CustomModel(keras.Model):\n",
        "    def train_step(self, data):\n",
        "        # Unpack the data. Its structure depends on your model and\n",
        "        # on what you pass to `fit()`.\n",
        "        x, y = data\n",
        "\n",
        "        with tf.GradientTape() as tape:\n",
        "            y_pred = self(x, training=True)  # Forward pass\n",
        "            # Compute the loss value\n",
        "            # (the loss function is configured in `compile()`)\n",
        "            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)\n",
        "\n",
        "        # Compute gradients\n",
        "        trainable_vars = self.trainable_variables\n",
        "        gradients = tape.gradient(loss, trainable_vars)\n",
        "        # Update weights\n",
        "        self.optimizer.apply_gradients(zip(gradients, trainable_vars))\n",
        "        # Update metrics (includes the metric that tracks the loss)\n",
        "        self.compiled_metrics.update_state(y, y_pred)\n",
        "        # Return a dict mapping metric names to current value\n",
        "        return {m.name: m.result() for m in self.metrics}\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c9d2cc7a7014"
      },
      "source": [
        "我们来试一下："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5e6bd7b554f6"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Construct and compile an instance of CustomModel\n",
        "inputs = keras.Input(shape=(32,))\n",
        "outputs = keras.layers.Dense(1)(inputs)\n",
        "model = CustomModel(inputs, outputs)\n",
        "model.compile(optimizer=\"adam\", loss=\"mse\", metrics=[\"mae\"])\n",
        "\n",
        "# Just use `fit` as usual\n",
        "x = np.random.random((1000, 32))\n",
        "y = np.random.random((1000, 1))\n",
        "model.fit(x, y, epochs=3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a882cb6467d6"
      },
      "source": [
        "## 在更低级别上操作\n",
        "\n",
        "当然，您可以直接跳过在 `compile()` 中传递损失函数，而在 `train_step` 中*手动*完成所有内容。指标也是如此。\n",
        "\n",
        "以下是一个较低级别的示例，仅使用 `compile()` 配置优化器：\n",
        "\n",
        "- 我们从创建 `Metric` 实例以跟踪我们的损失和 MAE 得分开始。\n",
        "- 我们实现可更新这些指标状态（通过对指标调用 `update_state()`）的自定义 `train_step()` ，然后对其进行查询（通过 `result()`）以返回其当前平均值，由进度条显示并传递给任何回调。\n",
        "- 请注意，需要在每个周期之间对指标调用 `reset_states()`！否则，调用 `result()` 会返回自训练开始以来的平均值，但我们通常要使用的是每个周期的平均值。幸运的是，该框架可以帮助我们实现：只需在模型的 `metrics` 属性中列出要重置的任何指标。模型将在每个 `fit()` 周期开始时或在开始调用 `evaluate()` 时对其中列出的任何对象调用 `reset_states()`。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2308abf5fe7d"
      },
      "outputs": [],
      "source": [
        "loss_tracker = keras.metrics.Mean(name=\"loss\")\n",
        "mae_metric = keras.metrics.MeanAbsoluteError(name=\"mae\")\n",
        "\n",
        "\n",
        "class CustomModel(keras.Model):\n",
        "    def train_step(self, data):\n",
        "        x, y = data\n",
        "\n",
        "        with tf.GradientTape() as tape:\n",
        "            y_pred = self(x, training=True)  # Forward pass\n",
        "            # Compute our own loss\n",
        "            loss = keras.losses.mean_squared_error(y, y_pred)\n",
        "\n",
        "        # Compute gradients\n",
        "        trainable_vars = self.trainable_variables\n",
        "        gradients = tape.gradient(loss, trainable_vars)\n",
        "\n",
        "        # Update weights\n",
        "        self.optimizer.apply_gradients(zip(gradients, trainable_vars))\n",
        "\n",
        "        # Compute our own metrics\n",
        "        loss_tracker.update_state(loss)\n",
        "        mae_metric.update_state(y, y_pred)\n",
        "        return {\"loss\": loss_tracker.result(), \"mae\": mae_metric.result()}\n",
        "\n",
        "    @property\n",
        "    def metrics(self):\n",
        "        # We list our `Metric` objects here so that `reset_states()` can be\n",
        "        # called automatically at the start of each epoch\n",
        "        # or at the start of `evaluate()`.\n",
        "        # If you don't implement this property, you have to call\n",
        "        # `reset_states()` yourself at the time of your choosing.\n",
        "        return [loss_tracker, mae_metric]\n",
        "\n",
        "\n",
        "# Construct an instance of CustomModel\n",
        "inputs = keras.Input(shape=(32,))\n",
        "outputs = keras.layers.Dense(1)(inputs)\n",
        "model = CustomModel(inputs, outputs)\n",
        "\n",
        "# We don't passs a loss or metrics here.\n",
        "model.compile(optimizer=\"adam\")\n",
        "\n",
        "# Just use `fit` as usual -- you can use callbacks, etc.\n",
        "x = np.random.random((1000, 32))\n",
        "y = np.random.random((1000, 1))\n",
        "model.fit(x, y, epochs=5)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f451e382c6a8"
      },
      "source": [
        "## 支持 `sample_weight` 和 `class_weight`\n",
        "\n",
        "您可能已经注意到，我们的第一个基本示例并没有提及样本加权。如果要支持 `fit()` 参数 `sample_weight` 和 `class_weight`，只需执行以下操作：\n",
        "\n",
        "- 从 `data` 参数中解包 `sample_weight`\n",
        "- 将其传递给 `compiled_loss` 和 `compiled_metrics`（当然，如果您不依赖 `compile()` 来获取损失和指标，也可以手动应用）\n",
        "- 就是这么简单。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "522d7281f948"
      },
      "outputs": [],
      "source": [
        "class CustomModel(keras.Model):\n",
        "    def train_step(self, data):\n",
        "        # Unpack the data. Its structure depends on your model and\n",
        "        # on what you pass to `fit()`.\n",
        "        if len(data) == 3:\n",
        "            x, y, sample_weight = data\n",
        "        else:\n",
        "            sample_weight = None\n",
        "            x, y = data\n",
        "\n",
        "        with tf.GradientTape() as tape:\n",
        "            y_pred = self(x, training=True)  # Forward pass\n",
        "            # Compute the loss value.\n",
        "            # The loss function is configured in `compile()`.\n",
        "            loss = self.compiled_loss(\n",
        "                y,\n",
        "                y_pred,\n",
        "                sample_weight=sample_weight,\n",
        "                regularization_losses=self.losses,\n",
        "            )\n",
        "\n",
        "        # Compute gradients\n",
        "        trainable_vars = self.trainable_variables\n",
        "        gradients = tape.gradient(loss, trainable_vars)\n",
        "\n",
        "        # Update weights\n",
        "        self.optimizer.apply_gradients(zip(gradients, trainable_vars))\n",
        "\n",
        "        # Update the metrics.\n",
        "        # Metrics are configured in `compile()`.\n",
        "        self.compiled_metrics.update_state(y, y_pred, sample_weight=sample_weight)\n",
        "\n",
        "        # Return a dict mapping metric names to current value.\n",
        "        # Note that it will include the loss (tracked in self.metrics).\n",
        "        return {m.name: m.result() for m in self.metrics}\n",
        "\n",
        "\n",
        "# Construct and compile an instance of CustomModel\n",
        "inputs = keras.Input(shape=(32,))\n",
        "outputs = keras.layers.Dense(1)(inputs)\n",
        "model = CustomModel(inputs, outputs)\n",
        "model.compile(optimizer=\"adam\", loss=\"mse\", metrics=[\"mae\"])\n",
        "\n",
        "# You can now use sample_weight argument\n",
        "x = np.random.random((1000, 32))\n",
        "y = np.random.random((1000, 1))\n",
        "sw = np.random.random((1000, 1))\n",
        "model.fit(x, y, sample_weight=sw, epochs=3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "03000c5590db"
      },
      "source": [
        "## 提供您自己的评估步骤\n",
        "\n",
        "如何对调用 `model.evaluate()` 进行相同的处理？您需要以完全相同的方式重写 `test_step`。如下所示："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "999edb22c50e"
      },
      "outputs": [],
      "source": [
        "class CustomModel(keras.Model):\n",
        "    def test_step(self, data):\n",
        "        # Unpack the data\n",
        "        x, y = data\n",
        "        # Compute predictions\n",
        "        y_pred = self(x, training=False)\n",
        "        # Updates the metrics tracking the loss\n",
        "        self.compiled_loss(y, y_pred, regularization_losses=self.losses)\n",
        "        # Update the metrics.\n",
        "        self.compiled_metrics.update_state(y, y_pred)\n",
        "        # Return a dict mapping metric names to current value.\n",
        "        # Note that it will include the loss (tracked in self.metrics).\n",
        "        return {m.name: m.result() for m in self.metrics}\n",
        "\n",
        "\n",
        "# Construct an instance of CustomModel\n",
        "inputs = keras.Input(shape=(32,))\n",
        "outputs = keras.layers.Dense(1)(inputs)\n",
        "model = CustomModel(inputs, outputs)\n",
        "model.compile(loss=\"mse\", metrics=[\"mae\"])\n",
        "\n",
        "# Evaluate with our custom test_step\n",
        "x = np.random.random((1000, 32))\n",
        "y = np.random.random((1000, 1))\n",
        "model.evaluate(x, y)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9e6a662e6588"
      },
      "source": [
        "## 总结：端到端 GAN 示例\n",
        "\n",
        "让我们看一个利用您刚刚所学全部内容的端到端示例。\n",
        "\n",
        "请考虑：\n",
        "\n",
        "- 旨在生成 28x28x1 图像的生成器网络。\n",
        "- 旨在将 28x28x1 图像分为两类（“fake”和“real”）的鉴别器网络。\n",
        "- 分别用于两个网络的优化器。\n",
        "- 训练鉴别器的损失函数。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6748db01dc7c"
      },
      "outputs": [],
      "source": [
        "from tensorflow.keras import layers\n",
        "\n",
        "# Create the discriminator\n",
        "discriminator = keras.Sequential(\n",
        "    [\n",
        "        keras.Input(shape=(28, 28, 1)),\n",
        "        layers.Conv2D(64, (3, 3), strides=(2, 2), padding=\"same\"),\n",
        "        layers.LeakyReLU(alpha=0.2),\n",
        "        layers.Conv2D(128, (3, 3), strides=(2, 2), padding=\"same\"),\n",
        "        layers.LeakyReLU(alpha=0.2),\n",
        "        layers.GlobalMaxPooling2D(),\n",
        "        layers.Dense(1),\n",
        "    ],\n",
        "    name=\"discriminator\",\n",
        ")\n",
        "\n",
        "# Create the generator\n",
        "latent_dim = 128\n",
        "generator = keras.Sequential(\n",
        "    [\n",
        "        keras.Input(shape=(latent_dim,)),\n",
        "        # We want to generate 128 coefficients to reshape into a 7x7x128 map\n",
        "        layers.Dense(7 * 7 * 128),\n",
        "        layers.LeakyReLU(alpha=0.2),\n",
        "        layers.Reshape((7, 7, 128)),\n",
        "        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding=\"same\"),\n",
        "        layers.LeakyReLU(alpha=0.2),\n",
        "        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding=\"same\"),\n",
        "        layers.LeakyReLU(alpha=0.2),\n",
        "        layers.Conv2D(1, (7, 7), padding=\"same\", activation=\"sigmoid\"),\n",
        "    ],\n",
        "    name=\"generator\",\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "801e8dd0c92a"
      },
      "source": [
        "这是特征齐全的 GAN 类，重写了 `compile()` 以使用其自己的签名，并在 `train_step` 的 17 行中实现了整个 GAN 算法："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bc3fb4111393"
      },
      "outputs": [],
      "source": [
        "class GAN(keras.Model):\n",
        "    def __init__(self, discriminator, generator, latent_dim):\n",
        "        super(GAN, self).__init__()\n",
        "        self.discriminator = discriminator\n",
        "        self.generator = generator\n",
        "        self.latent_dim = latent_dim\n",
        "\n",
        "    def compile(self, d_optimizer, g_optimizer, loss_fn):\n",
        "        super(GAN, self).compile()\n",
        "        self.d_optimizer = d_optimizer\n",
        "        self.g_optimizer = g_optimizer\n",
        "        self.loss_fn = loss_fn\n",
        "\n",
        "    def train_step(self, real_images):\n",
        "        if isinstance(real_images, tuple):\n",
        "            real_images = real_images[0]\n",
        "        # Sample random points in the latent space\n",
        "        batch_size = tf.shape(real_images)[0]\n",
        "        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))\n",
        "\n",
        "        # Decode them to fake images\n",
        "        generated_images = self.generator(random_latent_vectors)\n",
        "\n",
        "        # Combine them with real images\n",
        "        combined_images = tf.concat([generated_images, real_images], axis=0)\n",
        "\n",
        "        # Assemble labels discriminating real from fake images\n",
        "        labels = tf.concat(\n",
        "            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0\n",
        "        )\n",
        "        # Add random noise to the labels - important trick!\n",
        "        labels += 0.05 * tf.random.uniform(tf.shape(labels))\n",
        "\n",
        "        # Train the discriminator\n",
        "        with tf.GradientTape() as tape:\n",
        "            predictions = self.discriminator(combined_images)\n",
        "            d_loss = self.loss_fn(labels, predictions)\n",
        "        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)\n",
        "        self.d_optimizer.apply_gradients(\n",
        "            zip(grads, self.discriminator.trainable_weights)\n",
        "        )\n",
        "\n",
        "        # Sample random points in the latent space\n",
        "        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))\n",
        "\n",
        "        # Assemble labels that say \"all real images\"\n",
        "        misleading_labels = tf.zeros((batch_size, 1))\n",
        "\n",
        "        # Train the generator (note that we should *not* update the weights\n",
        "        # of the discriminator)!\n",
        "        with tf.GradientTape() as tape:\n",
        "            predictions = self.discriminator(self.generator(random_latent_vectors))\n",
        "            g_loss = self.loss_fn(misleading_labels, predictions)\n",
        "        grads = tape.gradient(g_loss, self.generator.trainable_weights)\n",
        "        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))\n",
        "        return {\"d_loss\": d_loss, \"g_loss\": g_loss}\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "095c499a6149"
      },
      "source": [
        "让我们对其进行测试："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "46832f2077ac"
      },
      "outputs": [],
      "source": [
        "# Prepare the dataset. We use both the training & test MNIST digits.\n",
        "batch_size = 64\n",
        "(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()\n",
        "all_digits = np.concatenate([x_train, x_test])\n",
        "all_digits = all_digits.astype(\"float32\") / 255.0\n",
        "all_digits = np.reshape(all_digits, (-1, 28, 28, 1))\n",
        "dataset = tf.data.Dataset.from_tensor_slices(all_digits)\n",
        "dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)\n",
        "\n",
        "gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)\n",
        "gan.compile(\n",
        "    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),\n",
        "    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),\n",
        "    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),\n",
        ")\n",
        "\n",
        "# To limit the execution time, we only train on 100 batches. You can train on\n",
        "# the entire dataset. You will need about 20 epochs to get nice results.\n",
        "gan.fit(dataset.take(100), epochs=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2ed211016c96"
      },
      "source": [
        "深度学习背后的思想十分简单，那么实现又何必复杂呢？"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "customizing_what_happens_in_fit.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
